/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "compatible/model_compatibility_check.h"

#include "model_builder/model_builder_types.h"

#include "framework/infra/log/log.h"
#include "base/common/cl_manager/ops_kernel_store_manager.h"

#include "framework/graph/core/cgraph/graph_list_walker.h"

#include "framework/graph/core/node/node_spec.h"

#include "infra/base/assertion.h"

namespace ge {
namespace {
// Verifies that every graph node whose type is not accepted by IsOpTypeInMainGraph is reported as
// supported by at least one CL, i.e. its name appears as a key in opDeviceSupport.
// Returns hiai::FAILURE (after logging the offending node) on the first node that is missing.
Status CheckOpSupported(OpNameTypeMap& directNodeNameType, OpNameDevicesMap& opDeviceSupport)
{
    for (const auto& node : directNodeNameType) {
        // Node types accepted by IsOpTypeInMainGraph are exempt from this check.
        if (IsOpTypeInMainGraph(node.second)) {
            continue;
        }
        const bool supportedByAnyCl = opDeviceSupport.find(node.first) != opDeviceSupport.end();
        if (!supportedByAnyCl) {
            FMK_LOGE("Node %s type %s don't support!", node.first.c_str(), node.second.c_str());
            return hiai::FAILURE;
        }
    }
    return hiai::SUCCESS;
}

// Chooses the device an op should run on, given the devices its CLs support (supportList).
// Preference: the requested deviceType if supported; otherwise CPU, but only when fallBackMode is ENABLE
// and CPU is supported. Returns hiai::ExecuteDevice::DEVICE_TYPE_RESERVED when no device qualifies.
hiai::ExecuteDevice FindOpSupportDeviceType(const hiai::FallBackMode& fallBackMode,
    const hiai::ExecuteDevice deviceType, const vector<hiai::ExecuteDevice>& supportList)
{
    auto isSupported = [&supportList](hiai::ExecuteDevice device) {
        return std::find(supportList.begin(), supportList.end(), device) != supportList.end();
    };

    if (isSupported(deviceType)) {
        return deviceType;
    }

    const bool canFallBackToCpu =
        (fallBackMode == hiai::FallBackMode::ENABLE) && isSupported(hiai::ExecuteDevice::CPU);
    return canFallBackToCpu ? hiai::ExecuteDevice::CPU : hiai::ExecuteDevice::DEVICE_TYPE_RESERVED;
}

// Binds the op described by pair (op name -> devices its CLs support) to exactly one device and records
// the single-element device list in chkResult.
//  - If directNodeNameType maps the op to a type accepted by IsOpTypeInMainGraph, it is bound to
//    deviceType without consulting pair.second.
//  - Otherwise FindOpSupportDeviceType picks deviceType if supported, else CPU when fallBackMode is
//    ENABLE and CPU is supported.
// Returns hiai::FAILURE (with a warning log) when no device could be chosen; chkResult is unchanged then.
Status BindSpecificCl(OpNameTypeMap& directNodeNameType, const hiai::ExecuteDevice deviceType,
    const OpNameDevicesPair& pair, const hiai::FallBackMode fallBackMode, OpNameDevicesMap& chkResult)
{
    auto opMapIter = directNodeNameType.find(pair.first);
    if (opMapIter != directNodeNameType.end() && IsOpTypeInMainGraph(opMapIter->second)) {
        chkResult.insert({pair.first, vector<hiai::ExecuteDevice> {deviceType}});
        return hiai::SUCCESS;
    }

    const hiai::ExecuteDevice opDevice = FindOpSupportDeviceType(fallBackMode, deviceType, pair.second);
    if (opDevice == hiai::ExecuteDevice::DEVICE_TYPE_RESERVED) {
        // Explicit cast: an enum (scoped enums in particular) is not promoted to int when passed through
        // varargs, so handing it directly to %d is a format/type mismatch.
        FMK_LOGW("op %s device type %d is not supported in chkResult", pair.first.c_str(),
            static_cast<int>(deviceType));
        return hiai::FAILURE;
    }

    chkResult.insert({pair.first, vector<hiai::ExecuteDevice> {opDevice}});
    return hiai::SUCCESS;
}

// Attempts to bind every op in opDeviceSupport using one model-wide deviceType (see BindSpecificCl).
// The bindings are collected in a scratch map and only swapped into chkResult once all ops succeeded,
// so a failure leaves chkResult exactly as it was.
Status GetCheckResultByModelDevice(const OpNameDevicesMap& opDeviceSupport, const hiai::FallBackMode& fallBackMode,
    const hiai::ExecuteDevice deviceType, OpNameTypeMap& directNodeNameType, OpNameDevicesMap& chkResult)
{
    OpNameDevicesMap candidate;
    for (const auto& opSupport : opDeviceSupport) {
        HIAI_EXPECT_EXEC(BindSpecificCl(directNodeNameType, deviceType, opSupport, fallBackMode, candidate));
    }

    // Commit: every op has a device.
    chkResult.swap(candidate);
    return hiai::SUCCESS;
}

// Walks option.modelDeviceConfig.modelDeviceOrder and stops at the first device for which every op can be
// bound (GetCheckResultByModelDevice). Returns hiai::FAILURE if no listed device works.
Status GetCheckResultByModelDeviceOrder(const OpNameDevicesMap& opDeviceSupport, OpNameTypeMap& directNodeNameType,
    const hiai::ModelCompileOptions& option, OpNameDevicesMap& chkResult)
{
    const auto& deviceConfig = option.modelDeviceConfig;
    for (const auto& deviceType : deviceConfig.modelDeviceOrder) {
        // WARNING[begin]: These log used for SUT test. Must confirm with tester before modification.
        FMK_LOGI("deviceType: %d,fallBackMode: %d", deviceType, option.modelDeviceConfig.fallBackMode);
        // WARNING[end]

        const Status ret = GetCheckResultByModelDevice(opDeviceSupport, deviceConfig.fallBackMode,
            static_cast<hiai::ExecuteDevice>(deviceType), directNodeNameType, chkResult);
        if (ret == hiai::SUCCESS) {
            return hiai::SUCCESS;
        }
    }

    // No device in the configured order could host the whole model.
    return hiai::FAILURE;
}

// Picks a device for an op that has no explicit per-op configuration: the first of NPU, CPU, GPU
// (in that order) that appears in supportList, or CPU if none of them does.
hiai::ExecuteDevice GetDefaultOpSupportDeviceType(const vector<hiai::ExecuteDevice>& supportList)
{
    const hiai::ExecuteDevice priority[] = {
        hiai::ExecuteDevice::NPU, hiai::ExecuteDevice::CPU, hiai::ExecuteDevice::GPU};
    const auto supportEnd = supportList.end();
    for (const hiai::ExecuteDevice candidate : priority) {
        if (std::find(supportList.begin(), supportEnd, candidate) != supportEnd) {
            return candidate;
        }
    }
    return hiai::ExecuteDevice::CPU;
}

// Binds each op in opDeviceSupport using the per-op configuration in option.modelDeviceConfig.opDeviceOrder.
//  - Ops without an entry get GetDefaultOpSupportDeviceType(supported devices).
//  - Ops with an entry must list exactly one device (otherwise hiai::FAILURE); it is bound via
//    BindSpecificCl, honoring option.modelDeviceConfig.fallBackMode.
// Bindings are accumulated in a scratch map and swapped into chkResult only when every op succeeded,
// mirroring GetCheckResultByModelDevice. On failure chkResult is left untouched, so a subsequent CPU
// roll back does not inherit partially bound (non-CPU) entries.
Status GetCheckResultByOpDeviceOrder(const OpNameDevicesMap& opDeviceSupport, OpNameTypeMap& directNodeNameType,
    const hiai::ModelCompileOptions& option, OpNameDevicesMap& chkResult)
{
    const auto& opDeviceOrder = option.modelDeviceConfig.opDeviceOrder;
    const auto fallBackMode = static_cast<hiai::FallBackMode>(option.modelDeviceConfig.fallBackMode);

    OpNameDevicesMap tmpResult;
    for (const auto& item : opDeviceSupport) {
        auto iter = opDeviceOrder.find(item.first);
        if (iter == opDeviceOrder.end()) {
            tmpResult.insert({item.first, vector<hiai::ExecuteDevice> {GetDefaultOpSupportDeviceType(item.second)}});
            continue;
        }
        const vector<hiai::ExecuteDevice>& opDeviceConfigs = iter->second;
        HIAI_EXPECT_TRUE_R(opDeviceConfigs.size() == 1, hiai::FAILURE);
        HIAI_EXPECT_EXEC(BindSpecificCl(directNodeNameType, opDeviceConfigs[0], item, fallBackMode, tmpResult));
    }
    chkResult.swap(tmpResult);

    return hiai::SUCCESS;
}

// Dispatches to the model-level device order when one is configured, otherwise to the per-op device order.
Status GetCheckResultByDeviceOrder(OpNameDevicesMap& opDeviceSupport, const hiai::ModelCompileOptions& option,
    OpNameTypeMap& directNodeNameType, OpNameDevicesMap& chkResult)
{
    const bool useModelOrder = !option.modelDeviceConfig.modelDeviceOrder.empty();
    return useModelOrder ?
        GetCheckResultByModelDeviceOrder(opDeviceSupport, directNodeNameType, option, chkResult) :
        GetCheckResultByOpDeviceOrder(opDeviceSupport, directNodeNameType, option, chkResult);
}

// Last-resort binding: when fallBackMode is ENABLE, bind every op in opDeviceSupport to CPU.
// Returns hiai::FAILURE if fall back is disabled or any op lacks CPU support; chkResult is then unchanged.
// On success chkResult is REPLACED (not merged) with the CPU bindings. A map insert keeps an existing key,
// so merging would let entries left over from an earlier, failed device-order attempt silently override
// the CPU binding.
Status TryRollBackToCpu(
    const hiai::FallBackMode& fallBackMode, OpNameDevicesMap& opDeviceSupport, OpNameDevicesMap& chkResult)
{
    if (fallBackMode != hiai::FallBackMode::ENABLE) {
        return hiai::FAILURE;
    }

    FMK_LOGI("roll back to cpu and check again");
    OpNameDevicesMap cpuResult;
    for (const auto& item : opDeviceSupport) {
        const auto& devices = item.second;
        if (std::find(devices.begin(), devices.end(), hiai::ExecuteDevice::CPU) == devices.end()) {
            return hiai::FAILURE;
        }
        cpuResult.insert({item.first, vector<hiai::ExecuteDevice> {hiai::ExecuteDevice::CPU}});
    }
    chkResult.swap(cpuResult);

    return hiai::SUCCESS;
}

// Produces the final op -> device binding for the graph.
// First ensures every non-main-graph node is supported by some CL (CheckOpSupported). Then tries the
// configured device order. If that fails, it tries the CPU roll back. isRollBackCpu reports which path
// produced chkResult and is left untouched when both fail.
Status GetCheckResult(OpNameDevicesMap& opDeviceSupport, const hiai::ModelCompileOptions& option,
    OpNameTypeMap& directNodeNameType, bool& isRollBackCpu, OpNameDevicesMap& chkResult)
{
    HIAI_EXPECT_EXEC(CheckOpSupported(directNodeNameType, opDeviceSupport));

    const bool boundByDeviceOrder =
        GetCheckResultByDeviceOrder(opDeviceSupport, option, directNodeNameType, chkResult) == hiai::SUCCESS;
    if (!boundByDeviceOrder) {
        HIAI_EXPECT_EXEC(TryRollBackToCpu(
            static_cast<hiai::FallBackMode>(option.modelDeviceConfig.fallBackMode), opDeviceSupport, chkResult));
    }
    isRollBackCpu = !boundByDeviceOrder;

    return hiai::SUCCESS;
}
} // namespace

// Checks whether irGraph can be compiled with the device configuration in option.
// Collects per-op CL/device support, indexes every graph node by name -> type, and delegates the
// binding decision to GetCheckResult, which fills chkResult and isRollBackCpu.
Status ModelCompatibilityCheck::CheckIRGraphCompatibility(const ComputeGraphPtr& irGraph,
    const hiai::ModelCompileOptions& option, bool& isRollBackCpu, OpNameDevicesMap& chkResult)
{
    HIAI_EXPECT_NOT_NULL_R(irGraph, hiai::FAILURE);

    OpNameDevicesMap opDeviceSupport;
    HIAI_EXPECT_EXEC(GetIRGraphSupportResult(irGraph, opDeviceSupport));

    OpNameTypeMap directNodeNameType;
    // The walker's own status is deliberately discarded; the visitor itself never fails.
    (void)irGraph->ROLE(GraphListWalker).WalkAllNodes([&directNodeNameType](ge::Node& node) {
        directNodeNameType.insert({node.ROLE(NodeSpec).Name(), node.ROLE(NodeSpec).Type()});
        return hiai::SUCCESS;
    });

    return GetCheckResult(opDeviceSupport, option, directNodeNameType, isRollBackCpu, chkResult);
}

namespace {
// Asks the ops-kernel info store registered under clName which ops of irGraph it supports. For each such
// op it appends the device backing that CL to chkResult[opName]. CLs missing from the local table are
// treated as CPU.
// A missing store is only logged as a warning and reported as SUCCESS (best effort).
Status GetIRGraphSupportResultInSpecialCl(
    const ComputeGraphPtr& irGraph, const string& clName, OpNameDevicesMap& chkResult)
{
    std::shared_ptr<OpsKernelInfoStore> opKernel = OpKernelStoreManager::GetInstance()->GetOpsKernelInfoStore(clName);
    if (opKernel == nullptr) {
        FMK_LOGW("get opKernel of name %s failed!", clName.c_str());
        return SUCCESS;
    }

    // CL name -> device it executes on.
#if defined AI_SUPPORT_DNNACL || defined AI_SUPPORT_ISPNN
    static const map<string, hiai::ExecuteDevice> deviceMap = {
        {NPU_CL, hiai::ExecuteDevice::NPU}, {DNNA_CL, hiai::ExecuteDevice::NPU},
        {NPUCL_ISPNN, hiai::ExecuteDevice::NPU}, {GPU_CL, hiai::ExecuteDevice::GPU}};
#else
    static const map<string, hiai::ExecuteDevice> deviceMap = {
        {NPU_CL, hiai::ExecuteDevice::NPU}, {GPU_CL, hiai::ExecuteDevice::GPU}};
#endif
    // clName is fixed for this call, so resolve its device once instead of once per supported op.
    const auto deviceIt = deviceMap.find(clName);
    const hiai::ExecuteDevice type = (deviceIt != deviceMap.end()) ? deviceIt->second : hiai::ExecuteDevice::CPU;

    for (const string& opName : opKernel->CheckSupported(irGraph)) {
        chkResult[opName].push_back(type);
    }

    return SUCCESS;
}
} // namespace

// Merges, into chkResult, the op support reported by every logic CL known to OpKernelStoreManager.
// Fails with hiai::FAILURE if irGraph is null or any per-CL query reports an error.
Status ModelCompatibilityCheck::GetIRGraphSupportResult(const ComputeGraphPtr& irGraph, OpNameDevicesMap& chkResult)
{
    HIAI_EXPECT_NOT_NULL_R(irGraph, hiai::FAILURE);

    // Take a copy of the CL name set so the iteration does not depend on the manager's internal storage.
    const std::set<std::string> clNames = OpKernelStoreManager::GetInstance()->GetLogicCLName();
    for (const auto& clName : clNames) {
        if (GetIRGraphSupportResultInSpecialCl(irGraph, clName, chkResult) == SUCCESS) {
            continue;
        }
        FMK_LOGE("get ir model check result of clName: %s failed", clName.c_str());
        return hiai::FAILURE;
    }

    return SUCCESS;
}
} // namespace ge
